import os
import numpy as np
import os.path as osp
import matplotlib.pyplot as plt
import cv2

def concat_imgs(img_dir, depth_dir, out_dir):
    """Save a side-by-side figure of each image and its predicted depth map.

    Every file in ``depth_dir`` named ``<stem>_pred_depth.png`` is paired with
    ``<stem>.png`` from ``img_dir``; the resulting two-panel matplotlib figure
    is written to ``<stem>.png`` in ``out_dir``.

    Args:
        img_dir: Directory holding the source images (``<stem>.png``).
        depth_dir: Directory holding the depth predictions
            (``<stem>_pred_depth.png``). Files without that suffix are ignored.
        out_dir: Directory that receives the figures; created if missing.

    Pairs whose image or depth file cannot be decoded are reported on stdout
    and skipped so that one bad file does not abort the whole batch.
    """
    depth_suffix = "_pred_depth.png"
    os.makedirs(out_dir, exist_ok=True)

    # Sorted so the processing order is reproducible across runs.
    for fname in sorted(os.listdir(depth_dir)):
        if not fname.endswith(depth_suffix):
            continue
        # str.strip() would treat the suffix as a *set of characters* and eat
        # matching characters from both ends of the stem, so slice instead.
        timestamp = fname[:-len(depth_suffix)]

        img_file = osp.join(img_dir, timestamp + '.png')
        depth_file = osp.join(depth_dir, fname)
        out_file = osp.join(out_dir, timestamp + '.png')

        # cv2.imread signals failure by returning None rather than raising.
        img = cv2.imread(img_file)
        if img is None:
            print("skip {}: cannot read image {}".format(timestamp, img_file))
            continue
        depth = cv2.imread(depth_file, -1)  # -1: load the file unchanged
        if depth is None:
            print("skip {}: cannot read depth {}".format(timestamp, depth_file))
            continue

        # OpenCV decodes colour images as BGR, matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Replicate the depth values over three channels so imshow renders them
        # as an RGB image (assumes the depth file decodes to a 2-D array
        # -- TODO confirm).
        depth = depth[..., None].repeat(3, 2)

        plt.subplot(121)
        plt.imshow(img)
        plt.subplot(122)
        plt.imshow(depth)
        plt.savefig(out_file)
        # Reset the figure so panels from this pair do not leak into the next.
        plt.clf()
        
        

if __name__ == "__main__":
    # Hard-coded input/output locations for this run; edit the paths below to
    # visualise a different set of images and depth predictions.
    concat_imgs(
        img_dir="/cv/yc/DSGN2/data/ww/training/image_2",
        depth_dir="/cv/yc/DSGN2/img_bbox3d/ww_val_bev_stereo1019/pred_depths_masks",
        out_dir="/cv/yc/DSGN2/img_bbox3d/ww_val_bev_stereo1019/concats",
    )